#!/usr/bin/env python3

import argparse
from enum import IntEnum
from http import HTTPStatus
import http.server
import io
import os
import re
import subprocess
import sys
import textwrap
import time
import threading
from uuid import uuid4

import libvirt
from lxml import etree
from PIL import Image
import qrcode
import yaml
import zbar

QEMU_NS = 'http://libvirt.org/schemas/domain/qemu/1.0'
etree.register_namespace('qemu', QEMU_NS)

OVMF_PATH = '/usr/share/OVMF/OVMF_CODE.fd'


class Verbosity(IntEnum):
    """Verbosity level

    Members are integers so that they can be compared directly against
    the number of times the ``--verbose`` flag was given; a higher count
    therefore enables every lower level as well.
    """

    # Report per-test progress, such as the VM launch and a passing result
    PROGRESS = 1
    # Additionally dump the generated boot script and VM definitions
    DEBUG = 2


class HTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
    """HTTP request handler that logs only unsuccessful requests"""

    def log_request(self, code='-', size='-'):
        """Log a request unless it completed with status 200 (OK)

        The signature mirrors the base class method, so the response
        size may be given either positionally or by keyword.
        """
        # Reduce HTTPStatus members to plain integers before comparing
        status = code.value if isinstance(code, HTTPStatus) else code
        if status == HTTPStatus.OK:
            return
        super().log_request(code=status, size=size)


def start_httpd():
    """Serve static files over HTTP from a background daemon thread

    The server binds to an operating-system-assigned port on localhost;
    the chosen port number is returned.
    """
    server = http.server.HTTPServer(('localhost', 0), HTTPRequestHandler)
    worker = threading.Thread(target=server.serve_forever, daemon=True)
    worker.start()
    _, port = server.server_address[:2]
    return port


def ipxe_script(uuid, version, arch, bootmgr, bcd, bootsdi, bootargs):
    """Build the iPXE script that boots one test case via wimboot

    The bootmgr, bcd and bootsdi flags select which optional files are
    loaded in addition to the ones that every test needs.
    """
    images = f"../images/{version}/{arch}"

    # Lines present in every script
    lines = [
        "#!ipxe",
        f"kernel ../wimboot {bootargs}",
        f"initrd -n qr.txt qr-{uuid}.txt qr.txt",
        f"initrd {images}/sources/boot.wim boot.wim",
        "initrd ../winpeshl.ini winpeshl.ini",
    ]

    # Lines present only when the test case asks for them
    optional = (
        (bootmgr, f"initrd {images}/bootmgr bootmgr"),
        (bcd, f"initrd {images}/boot/bcd BCD"),
        (bootsdi, f"initrd {images}/boot/boot.sdi boot.sdi"),
    )
    lines.extend(line for wanted, line in optional if wanted)

    lines.append("boot")
    return "".join(f"{line}\n" for line in lines)


def vm_xml(virttype, name, uuid, memory, uefi, logfile, romfile, booturl):
    """Build the libvirt domain XML for a network-booting test VM

    The serialised, pretty-printed XML is returned as a string.
    """
    def qemu_tag(tag):
        """Qualify a tag name with the libvirt QEMU namespace"""
        return '{%s}%s' % (QEMU_NS, tag)

    # Identity
    domain = etree.Element('domain', type=virttype)
    etree.SubElement(domain, 'name').text = name
    etree.SubElement(domain, 'uuid').text = uuid

    # Operating system: boot from the network, optionally with a loader
    guest_os = etree.SubElement(domain, 'os')
    etree.SubElement(guest_os, 'type', arch='x86_64').text = 'hvm'
    if uefi:
        etree.SubElement(guest_os, 'loader').text = OVMF_PATH
    etree.SubElement(guest_os, 'boot', dev='network')

    # Features
    features = etree.SubElement(domain, 'features')
    etree.SubElement(features, 'acpi')

    # Memory
    etree.SubElement(domain, 'memory', unit='MiB').text = str(memory)

    # Devices: display, video adapter and network interface
    devices = etree.SubElement(domain, 'devices')
    etree.SubElement(devices, 'graphics', type='spice')
    video = etree.SubElement(devices, 'video')
    etree.SubElement(video, 'model', type='vga')
    nic = etree.SubElement(devices, 'interface', type='user')
    etree.SubElement(nic, 'model', type='e1000')
    if romfile:
        etree.SubElement(nic, 'rom', file=romfile)

    # Extra QEMU arguments: boot file URL and debug console log file
    commandline = etree.SubElement(domain, qemu_tag('commandline'))
    extra_args = (
        '-set', 'netdev.%s.bootfile=%s' % ('hostnet0', booturl),
        '-debugcon', 'file:%s' % logfile,
    )
    for value in extra_args:
        etree.SubElement(commandline, qemu_tag('arg'), value=value)

    return etree.tostring(domain, pretty_print=True).decode().strip()


def screenshot(vm, screen=0):
    """Take screenshot of VM

    The image data is received over a libvirt stream into memory and
    returned as a fully loaded PIL image.
    """
    stream = vm.connect().newStream()
    vm.screenshot(stream, screen)
    with io.BytesIO() as fh:
        stream.recvAll(lambda s, d, f: f.write(d), fh)
        # Complete the transfer explicitly rather than leaving the
        # stream open until it happens to be garbage collected
        stream.finish()
        image = Image.open(fh)
        # Decode the pixel data now, since the buffer is closed on
        # leaving the "with" block
        image.load()
    return image


def qrcodes(image):
    """Return the data of each symbol that zbar decodes from an image"""
    # Hand zbar the pixels as 8-bit greyscale ('L' in PIL, 'Y800' in zbar)
    grey = image.convert('RGBA').convert('L')
    frame = zbar.Image(width=image.width, height=image.height,
                       format='Y800', data=grey.tobytes())
    scanner = zbar.ImageScanner()
    scanner.scan(frame)
    return [symbol.data for symbol in frame]


# Parse command-line arguments
parser = argparse.ArgumentParser(description="Run wimboot test case")
parser.add_argument(
    '--connection', '-c',
    metavar='URI',
    default='qemu:///session',
    help="Libvirt connection URI",
)
parser.add_argument(
    '--interactive', '-i',
    action='store_true',
    help="Launch interactive viewer",
)
parser.add_argument(
    '--romfile', '-r',
    metavar='FILE',
    help="iPXE boot ROM",
)
parser.add_argument(
    '--timeout', '-t',
    metavar='T',
    type=int,
    default=60,
    help="Timeout (in seconds)",
)
parser.add_argument(
    '--verbose', '-v',
    action='count',
    default=0,
    help="Increase verbosity",
)
parser.add_argument(
    'test',
    nargs='+',
    help="YAML test case(s)",
)
args = parser.parse_args()

# Start the HTTP daemon and remember the port it is listening on
http_port = start_httpd()

# Connect to libvirt
virt = libvirt.open(args.connection)

# Prefer KVM, falling back to plain QEMU if libvirt reports an error
# when asked for KVM domain capabilities
try:
    virt.getDomainCapabilities(virttype='kvm')
except libvirt.libvirtError:
    virttype = 'qemu'
else:
    virttype = 'kvm'

# Run test cases
#
# Each test case boots a transient VM from a per-test iPXE script and
# polls screenshots of it until a QR code carrying the per-test UUID is
# decoded from the screen, the timeout expires, or a viewer is closed.
failures = []
for test_file in args.test:

    # Load test case
    with open(test_file, 'rt') as fh:
        test = yaml.safe_load(fh)
    key = os.path.splitext(test_file)[0]
    name = test.get('name', key)
    version = test['version']
    arch = test['arch']
    uefi = test.get('uefi', False)
    memory = test.get('memory', 2048)
    bootmgr = test.get('bootmgr', False)
    bcd = test.get('bcd', True)
    bootsdi = test.get('boot.sdi', False)
    bootargs = test.get('bootargs', '')
    logcheck = test.get('logcheck', [])

    # Generate test UUID
    uuid = uuid4().hex

    # Construct input filenames
    bootfile = 'in/boot-%s.ipxe' % uuid
    qrfile = 'in/qr-%s.txt' % uuid

    # Construct debug log filename
    logfile = os.path.abspath('out/%s.log' % key)

    # Construct boot ROM filename
    romfile = os.path.abspath(args.romfile) if args.romfile else None

    # Construct boot URL
    booturl = 'http://${next-server}:%s/in/boot-%s.ipxe' % (http_port, uuid)

    # Track everything that needs cleaning up, so that the "finally"
    # clause can release it even if the test raises or is interrupted
    vm = None
    viewers = []
    passed = False
    aborted = False
    try:

        # Construct boot script
        script = ipxe_script(uuid, version, arch, bootmgr, bcd, bootsdi,
                             bootargs)
        if args.verbose >= Verbosity.DEBUG:
            print("%s boot script:\n%s\n" % (name, script.strip()))
        with open(bootfile, 'wt') as fh:
            fh.write(script)

        # Generate test QR code
        qr = qrcode.QRCode()
        qr.add_data(uuid)
        with open(qrfile, 'wt', newline='\r\n') as fh:
            qr.print_ascii(out=fh)

        # Launch VM
        xml = vm_xml(virttype, name, uuid, memory, uefi, logfile, romfile,
                     booturl)
        if args.verbose >= Verbosity.DEBUG:
            print("%s definition:\n%s\n" % (name, xml))
        vm = virt.createXML(xml, flags=libvirt.VIR_DOMAIN_START_AUTODESTROY)
        if args.verbose >= Verbosity.PROGRESS:
            print("%s launched as ID %d" % (name, vm.ID()))
        if args.verbose >= Verbosity.DEBUG:
            print("%s description:\n%s\n" % (name, vm.XMLDesc().strip()))

        # Launch interactive viewers, if requested
        #
        # Each viewer is recorded as soon as it starts, so that it is
        # still terminated if a later one fails to launch
        if args.interactive:
            viewers.append(subprocess.Popen(['virt-viewer', '--attach',
                                             '--connect', args.connection,
                                             '--id', str(vm.ID())]))
            viewers.append(subprocess.Popen(['tail', '-f', logfile]))

        # Wait for test to complete
        timeout = time.clock_gettime(time.CLOCK_MONOTONIC) + args.timeout
        while time.clock_gettime(time.CLOCK_MONOTONIC) < timeout:

            # Sleep for a second
            time.sleep(1)

            # Take screenshot
            image = screenshot(vm)
            image.save('out/%s.png' % key)

            # Abort if a viewer has been closed
            if any(x.poll() is not None for x in viewers):
                print("%s aborted" % name)
                aborted = True
                break

            # Wait for QR code to appear
            if uuid not in qrcodes(image):
                continue

            # Check for required log messages
            with open(logfile, 'rt') as fh:
                log = fh.read()
            logfail = [x for x in logcheck if not re.search(x, log)]
            if logfail:
                print("%s failed log check: %s" % (name, ', '.join(logfail)))
                break

            # Pass test
            if args.verbose >= Verbosity.PROGRESS:
                print("%s passed" % name)
            passed = True
            break

        else:

            # Timeout
            print("%s timed out" % name)

    finally:

        try:

            # Terminate viewers
            for viewer in viewers:
                viewer.terminate()
                viewer.wait()

            # Destroy VM
            if vm is not None:
                vm.destroy()

        finally:

            # Remove input files, tolerating any that were never created
            for path in (qrfile, bootfile):
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    pass

    # Record failure, if applicable
    if not passed:
        failures.append(name)

    # Abort testing, if applicable
    if aborted:
        break

# Report any failures
if failures:
    sys.exit("Failures: %s" % ','.join(failures))
